"""Support for esphome entities."""
from __future__ import annotations

from collections.abc import Callable
import functools
import math
from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast

from aioesphomeapi import (
    EntityCategory as EsphomeEntityCategory,
    EntityInfo,
    EntityState,
    build_unique_id,
)
import voluptuous as vol

from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EntityCategory
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import entity_platform
import homeassistant.helpers.config_validation as cv
import homeassistant.helpers.device_registry as dr
from homeassistant.helpers.device_registry import DeviceInfo
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.entity_platform import AddEntitiesCallback

from .domain_data import DomainData

# Import config flow so that it's added to the registry
from .entry_data import RuntimeEntryData
from .enum_mapper import EsphomeEnumMapper

_R = TypeVar("_R")
_InfoT = TypeVar("_InfoT", bound=EntityInfo)
_EntityT = TypeVar("_EntityT", bound="EsphomeEntity[Any,Any]")
_StateT = TypeVar("_StateT", bound=EntityState)


async def platform_async_setup_entry(
    hass: HomeAssistant,
    entry: ConfigEntry,
    async_add_entities: AddEntitiesCallback,
    *,
    info_type: type[_InfoT],
    entity_type: type[_EntityT],
    state_type: type[_StateT],
) -> None:
    """Set up an esphome platform.

    This method is in charge of receiving, distributing and storing
    info and state updates.
    """
    entry_data: RuntimeEntryData = DomainData.get(hass).get_entry_data(entry)
    entry_data.info[info_type] = {}
    entry_data.state.setdefault(state_type, {})
    platform = entity_platform.async_get_current_platform()

    @callback
    def async_list_entities(infos: list[EntityInfo]) -> None:
        """Update entities of this platform when entities are listed."""
        current_infos = entry_data.info[info_type]
        new_infos: dict[int, EntityInfo] = {}
        add_entities: list[_EntityT] = []

        for info in infos:
            if not current_infos.pop(info.key, None):
                # Create new entity
                entity = entity_type(entry_data, platform.domain, info, state_type)
                add_entities.append(entity)
            new_infos[info.key] = info

        # Anything still in current_infos is now gone
        if current_infos:
            hass.async_create_task(
                entry_data.async_remove_entities(current_infos.values())
            )

        # Then update the actual info
        entry_data.info[info_type] = new_infos

        if new_infos:
            entry_data.async_update_entity_infos(new_infos.values())

        if add_entities:
            # Add entities to Home Assistant
            async_add_entities(add_entities)

    entry_data.cleanup_callbacks.append(
        entry_data.async_register_static_info_callback(info_type, async_list_entities)
    )


def esphome_state_property(
    func: Callable[[_EntityT], _R]
) -> Callable[[_EntityT], _R | None]:
    """Wrap a state property of an esphome entity.

    This checks if the state object in the entity is set, and
    prevents writing NAN values to the Home Assistant state machine.
    """

    @functools.wraps(func)
    def _wrapper(self: _EntityT) -> _R | None:
        # pylint: disable-next=protected-access
        if not self._has_state:
            return None
        val = func(self)
        if isinstance(val, float) and not math.isfinite(val):
            # Home Assistant doesn't use NaN or inf values in state machine
            # (not JSON serializable)
            return None
        return val

    return _wrapper


ICON_SCHEMA = vol.Schema(cv.icon)


ENTITY_CATEGORIES: EsphomeEnumMapper[
    EsphomeEntityCategory, EntityCategory | None
] = EsphomeEnumMapper(
    {
        EsphomeEntityCategory.NONE: None,
        EsphomeEntityCategory.CONFIG: EntityCategory.CONFIG,
        EsphomeEntityCategory.DIAGNOSTIC: EntityCategory.DIAGNOSTIC,
    }
)


class EsphomeEntity(Entity, Generic[_InfoT, _StateT]):
    """Define a base esphome entity."""

    _attr_should_poll = False
    _static_info: _InfoT
    _state: _StateT
    _has_state: bool

    def __init__(
        self,
        entry_data: RuntimeEntryData,
        domain: str,
        entity_info: EntityInfo,
        state_type: type[_StateT],
    ) -> None:
        """Initialize."""

        self._entry_data = entry_data
        self._on_entry_data_changed()
        self._key = entity_info.key
        self._state_type = state_type
        self._on_static_info_update(entity_info)
        assert entry_data.device_info is not None
        device_info = entry_data.device_info
        self._device_info = device_info
        self._attr_device_info = DeviceInfo(
            connections={(dr.CONNECTION_NETWORK_MAC, device_info.mac_address)}
        )
        self._entry_id = entry_data.entry_id
        #
        # If `friendly_name` is set, we use the Friendly naming rules, if
        # `friendly_name` is not set we make an exception to the naming rules for
        # backwards compatibility and use the Legacy naming rules.
        #
        # Friendly naming
        # - Friendly name is prepended to entity names
        # - Device Name is prepended to entity ids
        # - Entity id is constructed from device name and object id
        #
        # Legacy naming
        # - Device name is not prepended to entity names
        # - Device name is not prepended to entity ids
        # - Entity id is constructed from entity name
        #
        if not device_info.friendly_name:
            return
        self._attr_has_entity_name = True
        self.entity_id = f"{domain}.{device_info.name}_{entity_info.object_id}"

    async def async_added_to_hass(self) -> None:
        """Register callbacks."""
        entry_data = self._entry_data
        hass = self.hass
        key = self._key

        self.async_on_remove(
            entry_data.async_register_key_static_info_remove_callback(
                self._static_info,
                functools.partial(self.async_remove, force_remove=True),
            )
        )
        self.async_on_remove(
            async_dispatcher_connect(
                hass,
                entry_data.signal_device_updated,
                self._on_device_update,
            )
        )
        self.async_on_remove(
            entry_data.async_subscribe_state_update(
                self._state_type, key, self._on_state_update
            )
        )
        self.async_on_remove(
            entry_data.async_register_key_static_info_updated_callback(
                self._static_info, self._on_static_info_update
            )
        )
        self._update_state_from_entry_data()

    @callback
    def _on_static_info_update(self, static_info: EntityInfo) -> None:
        """Save the static info for this entity when it changes.

        This method can be overridden in child classes to know
        when the static info changes.
        """
        device_info = self._entry_data.device_info
        if TYPE_CHECKING:
            static_info = cast(_InfoT, static_info)
            assert device_info
        self._static_info = static_info
        self._attr_unique_id = build_unique_id(device_info.mac_address, static_info)
        self._attr_entity_registry_enabled_default = not static_info.disabled_by_default
        self._attr_name = static_info.name
        if entity_category := static_info.entity_category:
            self._attr_entity_category = ENTITY_CATEGORIES.from_esphome(entity_category)
        else:
            self._attr_entity_category = None
        if icon := static_info.icon:
            self._attr_icon = cast(str, ICON_SCHEMA(icon))
        else:
            self._attr_icon = None

    @callback
    def _update_state_from_entry_data(self) -> None:
        """Update state from entry data."""

        state = self._entry_data.state
        key = self._key
        state_type = self._state_type
        has_state = key in state[state_type]
        if has_state:
            self._state = cast(_StateT, state[state_type][key])
        self._has_state = has_state

    @callback
    def _on_state_update(self) -> None:
        """Call when state changed.

        Behavior can be changed in child classes
        """
        self._update_state_from_entry_data()
        self.async_write_ha_state()

    @callback
    def _on_entry_data_changed(self) -> None:
        entry_data = self._entry_data
        self._api_version = entry_data.api_version
        self._client = entry_data.client

    @callback
    def _on_device_update(self) -> None:
        """Call when device updates or entry data changes."""
        self._on_entry_data_changed()
        if not self._entry_data.available:
            # Only write state if the device has gone unavailable
            # since _on_state_update will be called if the device
            # is available when the full state arrives
            # through the next entity state packet.
            self.async_write_ha_state()

    @property
    def available(self) -> bool:
        """Return if the entity is available."""
        if self._device_info.has_deep_sleep:
            # During deep sleep the ESP will not be connectable (by design)
            # For these cases, show it as available
            return self._entry_data.expected_disconnect

        return self._entry_data.available


class EsphomeAssistEntity(Entity):
    """Define a base entity for Assist Pipeline entities."""

    _attr_has_entity_name = True
    _attr_should_poll = False

    def __init__(self, entry_data: RuntimeEntryData) -> None:
        """Initialize the binary sensor."""
        self._entry_data: RuntimeEntryData = entry_data
        assert entry_data.device_info is not None
        device_info = entry_data.device_info
        self._device_info = device_info
        self._attr_unique_id = (
            f"{device_info.mac_address}-{self.entity_description.key}"
        )
        self._attr_device_info = DeviceInfo(
            connections={(dr.CONNECTION_NETWORK_MAC, device_info.mac_address)}
        )

    @callback
    def _update(self) -> None:
        self.async_write_ha_state()

    async def async_added_to_hass(self) -> None:
        """Register update callback."""
        await super().async_added_to_hass()
        self.async_on_remove(
            self._entry_data.async_subscribe_assist_pipeline_update(self._update)
        )
